Ft/ensemble changes #115
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Conversation
… uses target synthetic data.
…rent experimental setups

- Added testing several targets on multiple GPUs
- Added a comment
challenge_data_path: ${target_model.target_model_directory}/${target_model.target_model_name}/challenge_with_id.csv
challenge_label_path: ${target_model.target_model_directory}/${target_model.target_model_name}/challenge_label.csv
target_attack_artifact_dir: ${base_experiment_dir}/target_${target_model.target_model_id}_attack_artifacts/
This directory was extra and can be removed.
@@ -1,34 +1,36 @@
# Ensemble experiment configuration
The current single config is hard to understand because it mixes many variables and data paths with unclear names inherited from the original attack code. Splitting it into multiple pipeline‑specific configs would improve clarity and maintainability, even if it adds some overhead. Alternatively, improving variable naming within one config could be helpful.
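One way the suggested split could look, as a hedged sketch: the group and file names below are invented for illustration (they are not the repo's actual configs), using Hydra's defaults list to compose pipeline-specific config groups.

```yaml
# Hypothetical layout: one small top-level config composing
# pipeline-specific groups instead of a single monolithic file.
defaults:
  - data_paths: midst_10k        # population/challenge paths only
  - shadow_training: tabddpm     # shadow-model hyperparameters
  - metaclassifier: xgboost      # blending / metaclassifier settings
  - _self_

base_experiment_dir: ./experiments/ensemble
```

Each group file then carries only the variables its pipeline stage needs, which also makes renaming the unclear inherited variables easier to do incrementally.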
population.append(df_real)
population.append(df_real)
This was a bug! Thank you for catching this, Sara!
# Load the required dataframes for shadow model training.
# For shadow model training we need master_challenge_train and population data.
# Master challenge is the main training (or fine-tuning) data for the shadow models.
df_master_challenge_train = load_dataframe(
Instead of loading the data here, it is passed to the function.
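The refactor described here follows a common dependency-injection pattern: the caller loads the data once and passes the DataFrame in. A minimal sketch, with illustrative names rather than the toolkit's exact signatures:

```python
import pandas as pd

# Before (hypothetical): the function loaded its own inputs from disk,
# hiding I/O inside the function and making it hard to test.
def run_shadow_model_training_old(csv_path: str) -> int:
    df_challenge_train = pd.read_csv(csv_path)
    return len(df_challenge_train)

# After: the caller injects the DataFrame, so tests can pass an
# in-memory frame and the load happens in exactly one place.
def run_shadow_model_training(df_challenge_train: pd.DataFrame) -> int:
    # placeholder body: the real code fine-tunes shadow models here
    return len(df_challenge_train)
```

This also lets several pipeline stages share one loaded copy of the data instead of re-reading it from disk.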
f"Fine-tuned model {model_id} generated {len(train_result.synthetic_data)} synthetic samples.",
)
attack_data["fine_tuned_results"].append(train_result)
attack_data["fine_tuned_results"].append(train_result.synthetic_data)
We only need to save the synthetic data.
target_shadow_models_output_path: ${target_model.target_attack_artifact_dir}/tabddpm_${target_model.target_model_id}_shadows_dir
target_shadow_models_output_path: ${base_experiment_dir}/test_all_targets # Sub-directory to store test shadows and results
attack_probabilities_result_path: ${target_model.target_shadow_models_output_path}/test_probabilities/attack_model_${target_model.target_model_id}_proba
attack_rmia_shadow_training_data_choice: "combined" # Options: "combined", "only_challenge", "only_train". This determines which data to use for training RMIA attack model in testing phase.
This is a new config variable. You can read more about the options in select_challenge_data_for_training()'s docstring.
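A minimal sketch of what the three options could mean, assuming the semantics implied by their names; the authoritative definition is the real select_challenge_data_for_training() docstring, and the real function's signature may differ.

```python
import pandas as pd

def select_challenge_data_for_training(
    df_challenge: pd.DataFrame,
    df_train: pd.DataFrame,
    choice: str,
) -> pd.DataFrame:
    """Pick the RMIA shadow training data according to the config option."""
    if choice == "combined":
        # concatenate challenge and train data into one pool
        return pd.concat([df_challenge, df_train], ignore_index=True)
    if choice == "only_challenge":
        return df_challenge
    if choice == "only_train":
        return df_train
    raise ValueError(
        f"Unknown choice {choice!r}; expected 'combined', 'only_challenge' or 'only_train'"
    )
```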
@hydra.main(config_path="configs", config_name="experiment_config", version_base=None)
def run_metaclassifier_testing(
This function executes the attack on a single target model (target_model.target_model_id). However, all target models within an experiment can share the same trained RMIA shadow models. Our current workflow is to train the RMIA shadow models using one target model, and then run tests on all remaining targets in parallel using run_test.sh. Each of these targets simply loads the previously trained RMIA shadow models. This approach was originally designed to speed up testing.
Later, we realized that the main runtime bottleneck (testing phase) is actually the RMIA shadow‑model training step. As a result, a potential refactoring improvement would be to modify this function so that it trains the RMIA shadow models once and then sequentially tests a set of target models within a single function call. This can simplify the testing process with potentially little to no loss in efficiency.
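The proposed refactor can be sketched abstractly. The callables below are stand-ins for the real training and testing steps, not the toolkit's actual API:

```python
from typing import Callable, Iterable

def run_all_targets(
    target_ids: Iterable[int],
    train_rmia_shadows: Callable[[], object],
    test_one_target: Callable[[int, object], float],
) -> dict[int, float]:
    """Train the shared RMIA shadow models once, then test every target
    sequentially against them, instead of re-loading shadows per target."""
    shadows = train_rmia_shadows()  # the expensive step, done exactly once
    return {tid: test_one_target(tid, shadows) for tid in target_ids}
```

Since the shadow-model training dominates runtime, the sequential loop over targets inside one call should cost little compared to the parallel-jobs workflow it replaces.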
📝 Walkthrough

This pull request refactors the ensemble attack pipeline from 20k to 10k data, restructuring the data collection workflow to use pre-loaded population and challenge datasets passed as parameters rather than loaded internally, and optimizing distance computations through batched and multiprocessing-enabled Gower calculations. The changes include configuration updates with expanded data splits, removal of on-disk data loading in favor of externally managed DataFrames, refactoring of internal data representations to store synthetic data directly instead of TrainingResult objects, and comprehensive test updates reflecting the new data handling approach.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Pre-merge checks: ✅ 2 passed | ❌ 1 inconclusive
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In @examples/ensemble_attack/configs/experiment_config.yaml:
- Around lines 109-110: In the attack_success_computation block, the target_ids_to_test list contains a duplicated 26 and is missing 27. Replace the duplicated 26 with 27 so that each target ID appears exactly once and the sequence is correct.
- Around lines 2-4: Fix the typo in the top comment of the YAML config: the comment currently reads "run_attack.py and tets_attack_model.py"; change "tets_attack_model.py" to "test_attack_model.py" so the referenced test script name is correct.
In @examples/ensemble_attack/real_data_collection.py:
- Around lines 182-195: Remove the duplicated block in examples/ensemble_attack/real_data_collection.py that repeats the population_splits/challenge_splits defaults and the redundant save_dir.mkdir call. Keep the first occurrence that sets population_splits = ["train"] and challenge_splits = ["train", "dev", "final"], and delete the second duplicate block (the repeated `if population_splits is None` / `if challenge_splits is None` checks and the extra mkdir), so that only one mkdir call and one defaults assignment remain and callers still see the intended defaults.
In @examples/ensemble_attack/run_metaclassifier_training.py:
- Around lines 89-92: The log call that prints the reference population path contains a malformed f-string ("f{config.data_paths.population_path}"), so the literal text "f{...}" is logged instead of the path. In the logging statement near where df_reference is used in run_metaclassifier_training, remove the stray "f" before the opening curly brace so the f-string interpolates the actual value of config.data_paths.population_path.
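The bug pattern is easy to reproduce in isolation (the path value below is invented for illustration):

```python
path = "/data/population.csv"  # illustrative value, not the repo's config

# The stray inner 'f' is just literal text inside an already-formatted string:
broken = f"Loading reference population from f{path}"
# Removing it lets the placeholder interpolate normally:
fixed = f"Loading reference population from {path}"

# broken -> "Loading reference population from f/data/population.csv"
# fixed  -> "Loading reference population from /data/population.csv"
```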
- Around lines 28-29: The docstring for the metaclassifier training entry point duplicates the parameter description for target_model_synthetic_path. Remove the redundant entry so target_model_synthetic_path is described exactly once.
In @examples/ensemble_attack/run_shadow_model_training.py:
- Around lines 106-110: The code contains a duplicate assertion checking "trans_id" in df_challenge_train.columns. Delete the second occurrence so df_challenge_train is asserted exactly once, and keep the existing assertion for df_population_with_challenge.
In @examples/ensemble_attack/test_attack_model.py:
- Around lines 96-97: The shadow_model_paths value returned by run_shadow_model_training(...) is immediately discarded by an overwrite from config.shadow_training.final_shadow_models_path. Delete the second assignment (shadow_model_paths = [Path(path) for path in config.shadow_training.final_shadow_models_path]) so that downstream logic uses the paths returned by run_shadow_model_training. Alternatively, if the config paths are truly intended, remove the run_shadow_model_training call instead, but prefer keeping its return value.
🧹 Nitpick comments (3)
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (1)
350-388: Consider making `n_jobs` configurable. The `n_jobs=4` is hardcoded in multiple calls to `get_rmia_gower`. Consider passing this as a parameter to `calculate_rmia_signals` for flexibility across different hardware configurations.

tests/integration/attacks/ensemble/test_shadow_model_training.py (1)

68-71: Consider reordering assertions for clearer error messages. If `synthetic_data` were `None`, the type assertion would fail with a confusing message. Consider checking for `None` first. Suggested order:

for synthetic_data in shadow_data["fine_tuned_results"]:
-   assert type(synthetic_data) is pd.DataFrame
    assert synthetic_data is not None
+   assert type(synthetic_data) is pd.DataFrame
    assert len(synthetic_data) == 5

examples/ensemble_attack/test_attack_model.py (1)

162-168: Consider making the data split configurable. The hardcoded `data_splits=["test"]`, with the comment suggesting manual changes for different experiments, could be error-prone. Consider extracting this to a config parameter.
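The `n_jobs` suggestion amounts to threading a parameter through instead of hardcoding the literal at each call site. A minimal sketch with placeholder function bodies; the real toolkit signatures may differ:

```python
def get_rmia_gower(data: list[float], n_jobs: int = 4) -> int:
    # placeholder: the real function would dispatch the Gower
    # computation across n_jobs worker processes
    return n_jobs

def calculate_rmia_signals(data: list[float], n_jobs: int = 4) -> int:
    # forward the caller's choice rather than a hardcoded 4,
    # so different hardware can tune the worker count from config
    return get_rmia_gower(data, n_jobs=n_jobs)
```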
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- examples/ensemble_attack/configs/experiment_config.yaml
- examples/ensemble_attack/real_data_collection.py
- examples/ensemble_attack/run_attack.py
- examples/ensemble_attack/run_metaclassifier_training.py
- examples/ensemble_attack/run_shadow_model_training.py
- examples/ensemble_attack/run_train.sh
- examples/ensemble_attack/test_attack_model.py
- src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py
- src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
- tests/integration/attacks/ensemble/test_shadow_model_training.py
- tests/unit/attacks/ensemble/test_rmia.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-12-11T16:08:49.024Z
Learnt from: lotif
Repo: VectorInstitute/midst-toolkit PR: 107
File: examples/gan/synthesize.py:1-47
Timestamp: 2025-12-11T16:08:49.024Z
Learning: When using SDV (version >= 1.18.0), prefer loading a saved CTGANSynthesizer with CTGANSynthesizer.load(filepath) instead of sdv.utils.load_synthesizer(). This applies to Python code across the repo (e.g., any script that loads a CTGANSynthesizer). Ensure the SDV version is >= 1.18.0 before using CTGANSynthesizer.load, and fall back to sdv.utils.load_synthesizer() only if a compatible alternative is required.
Applied to files:
- src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py
- examples/ensemble_attack/run_shadow_model_training.py
- examples/ensemble_attack/real_data_collection.py
- tests/integration/attacks/ensemble/test_shadow_model_training.py
- tests/unit/attacks/ensemble/test_rmia.py
- examples/ensemble_attack/run_metaclassifier_training.py
- src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py
- examples/ensemble_attack/test_attack_model.py
- examples/ensemble_attack/run_attack.py
🧬 Code graph analysis (5)
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
src/midst_toolkit/attacks/ensemble/shadow_model_utils.py (2)
fine_tune_tabddpm_and_synthesize (158-248), TrainingResult (26-33)
examples/ensemble_attack/run_shadow_model_training.py (1)
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (1)
train_three_sets_of_shadow_models(309-442)
tests/unit/attacks/ensemble/test_rmia.py (1)
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (2)
Key (23-25), get_rmia_gower (136-215)
examples/ensemble_attack/test_attack_model.py (3)
examples/ensemble_attack/real_data_collection.py (2)
AttackType (17-31), collect_midst_data (101-142)
src/midst_toolkit/attacks/ensemble/blending.py (1)
MetaClassifierType (21-23)
src/midst_toolkit/attacks/ensemble/data_utils.py (1)
load_dataframe(31-52)
examples/ensemble_attack/run_attack.py (3)
src/midst_toolkit/attacks/ensemble/data_utils.py (1)
load_dataframe (31-52)
examples/ensemble_attack/real_data_collection.py (1)
collect_population_data_ensemble (145-256)
examples/ensemble_attack/run_shadow_model_training.py (1)
run_shadow_model_training(83-134)
🪛 Ruff (0.14.10)
examples/ensemble_attack/test_attack_model.py
135-135: pickle and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
(S301)
224-226: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (44)
examples/ensemble_attack/run_metaclassifier_training.py (2)
68: LGTM! Good addition of logging after loading shadow model data to aid debugging and traceability.

77-80: LGTM! Helpful logging addition for tracing data loading with size information.
examples/ensemble_attack/real_data_collection.py (7)
7: LGTM! Appropriate imports added for logging functionality. Also applies to: 14.

64-66: LGTM! Documentation improvements clarify the meaning of the `data_split` parameter. Also applies to: 82-88.

110-121: LGTM! Enhanced documentation for the `data_splits` parameter improves clarity.

140-142: LGTM! Simplified append logic looks correct.

149: LGTM! Updated function signature and comprehensive docstring for the new `original_repo_population` parameter. Also applies to: 156-174.

198-213: LGTM! Good addition of logging for population data collection and the concatenation with `original_repo_population`.

231: LGTM! Helpful logging for challenge data collection with splits information.
src/midst_toolkit/attacks/ensemble/rmia/rmia_calculation.py (6)
7-11: LGTM! Good additions: `Sequence` for type hints, `Pool` for multiprocessing, and a `FloatDType` type alias for clearer typing. Also applies to: 20-21.

28-70: LGTM! Well-implemented batched Gower distance computation. Pre-allocating the output matrix and processing in chunks is an effective pattern for reducing peak memory usage. The batch size of 5000 used at the call site (line 131) is reasonable.
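The batching pattern praised here can be illustrated with a plain L1 distance standing in for the Gower metric (the function name and batch size are illustrative, not the toolkit's):

```python
import numpy as np

def batched_pairwise_l1(a: np.ndarray, b: np.ndarray, batch_size: int = 5000) -> np.ndarray:
    """Fill a pre-allocated (len(a), len(b)) distance matrix in row-batches.

    Peak extra memory is roughly batch_size * len(b) * dim floats for the
    broadcast intermediate, instead of len(a) * len(b) * dim all at once.
    """
    out = np.empty((a.shape[0], b.shape[0]), dtype=np.float32)
    for start in range(0, a.shape[0], batch_size):
        stop = min(start + batch_size, a.shape[0])
        # broadcast one batch of rows of `a` against all rows of `b`
        out[start:stop] = np.abs(a[start:stop, None, :] - b[None, :, :]).sum(axis=2)
    return out
```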
73-133: LGTM! Good implementation of a multiprocessing-friendly wrapper. Creating a copy of `df_synthetic` (line 118) before modifications is correct to avoid mutation issues. The tuple-based argument passing is appropriate for `Pool.imap_unordered`.
136-215: LGTM! Solid multiprocessing implementation. Using `imap_unordered` with index tracking in a dict, then reconstructing the original order, is the correct pattern for parallel processing where order matters. The fallback to sequential processing when `use_multiprocessing=False` is useful for debugging and testing.
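The imap_unordered-with-index-tracking pattern looks like this in minimal form (names are illustrative; the sequential branch mirrors the `use_multiprocessing=False` fallback described above):

```python
from multiprocessing import Pool

def _square_indexed(task: tuple[int, int]) -> tuple[int, int]:
    # the worker returns its index alongside the result,
    # so the caller can restore the original order later
    i, x = task
    return i, x * x

def squares_in_order(xs: list[int], use_multiprocessing: bool = True) -> list[int]:
    tasks = list(enumerate(xs))
    if use_multiprocessing:
        with Pool(processes=2) as pool:
            # results arrive in completion order; the dict keys fix that
            results = dict(pool.imap_unordered(_square_indexed, tasks))
    else:
        # sequential fallback, handy for debugging and for mocks in tests
        results = dict(map(_square_indexed, tasks))
    # reassemble in the original order regardless of completion order
    return [results[i] for i in range(len(xs))]
```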
334-340: LGTM! Updated to work with DataFrames directly instead of TrainingResult objects, aligning with the refactored data storage approach.

394-396: LGTM! Logging additions and updates to use `shadow_training_data_ids` for mask creation are consistent with the refactored data handling approach. Also applies to: 406, 422-427, 455.
src/midst_toolkit/attacks/ensemble/rmia/shadow_model_training.py (3)
187: LGTM! Good refactor to store only `synthetic_data` (a DataFrame) instead of the full `TrainingResult` object. This reduces memory usage and pickle file sizes. The assertion on line 182 ensures the data is not None before appending.

299: LGTM! Consistent with the change in `train_fine_tuned_shadow_models`. The assertion on line 293 validates the data before appending.

441-442: LGTM! Trailing whitespace change has no functional impact.
examples/ensemble_attack/run_train.sh (3)
6-12: LGTM, but verify GPU availability. Resource increases align with the ensemble attack requirements. The specific `gpu:a100:1` request may fail if A100 GPUs are unavailable on the cluster. Consider using a more generic GPU request or documenting the A100 requirement.

15: LGTM! Useful memory logging for debugging resource allocation.

24: LGTM! Config name updated to target the 10k data experiment configuration.
tests/integration/attacks/ensemble/test_shadow_model_training.py (1)
107-109: LGTM! Test assertions correctly updated to validate DataFrame type and expected length. Note: this test doesn't check for `None`, but the production code assertion at line 293 in `shadow_model_training.py` ensures this won't happen.

examples/ensemble_attack/run_attack.py (3)
14: LGTM! Import added to support the new data loading functionality.

27-34: LGTM! Good addition of loading the original repository population data and passing it to `collect_population_data_ensemble`. The comment clearly explains why this is needed (to provide a larger population dataset for DOMIAS). Also applies to: 41.

81-85: LGTM! Correctly loads the master challenge training data and passes it to the updated `run_shadow_model_training` function signature.

examples/ensemble_attack/run_shadow_model_training.py (3)
5: LGTM! The added pandas import is necessary to support the new DataFrame type hint in the function signature.

83-95: LGTM! Good refactor to accept `df_challenge_train` as a parameter instead of loading from disk. This aligns with the broader PR goal of passing DataFrames directly rather than loading internally, improving testability and flexibility.

114-128: LGTM! The updated call to `train_three_sets_of_shadow_models` correctly passes `df_challenge_train` as `master_challenge_data`, which aligns with the function's signature in `shadow_model_training.py`.

tests/unit/attacks/ensemble/test_rmia.py (6)
45-53: LGTM! The model_data fixture correctly stores DataFrames directly instead of wrapping them in mock objects, aligning with the refactored API that expects `list[pd.DataFrame]` for `model_data`.

77-92: LGTM! The rmia_signal_data fixture is correctly updated to store DataFrames directly in the `fine_tuned_results` and `trained_results` lists, consistent with the new data model.

151-169: LGTM! Good test updates:
- Using `list(...)` to extract DataFrames from `model_data`
- `use_multiprocessing=False` ensures mocks work correctly in the main process
- `dtype=np.float32` in expected arrays matches the function's default dtype

173-179: LGTM! Correctly accessing DataFrames directly from model_data instead of through the `.synthetic_data` attribute.

181-216: LGTM! The sampling test is well-updated with:
- Direct DataFrame access
- `use_multiprocessing=False` for deterministic behavior
- Enhanced `assert_frame_equal` with a descriptive `obj` parameter for better debugging

218-237: LGTM! The missing categorical column test correctly uses `list(...)` to extract DataFrames from the fixture.

examples/ensemble_attack/configs/experiment_config.yaml (2)
20-22: LGTM! Good addition of the `attack_rmia_shadow_training_data_choice` option with clear options documented in the comment. This provides flexibility for controlling RMIA shadow training dataset selection.

48-55: LGTM! Good expansion of `challenge_splits` and `folder_ranges` to support the test phase data collection. The ranges are clearly structured.

examples/ensemble_attack/test_attack_model.py (8)
22-45: LGTM! Good extraction of result saving logic into a dedicated helper function. The function handles both saving probabilities and optionally saving the TPR@FPR=0.1 score.

47-77: LGTM! Clean helper function for extracting and dropping ID columns. Good use of assertions for validation.

114-142: LGTM! Good implementation of `load_trained_rmia_shadows_for_test_phase`. The function correctly checks the existence of all models before loading and returns early with an empty list if any model is missing. Regarding the static analysis hint about pickle (S301): this is internal research tooling loading models from known paths, so the security risk is acceptable in this context.

145-184: LGTM! Well-structured helper function for collecting challenge and train data with clear logging.

187-228: LGTM! Good implementation of `select_challenge_data_for_training` with clear documentation of the three options. The ValueError for invalid choices provides a helpful error message. Regarding the static analysis hint (TRY003): the detailed error message is appropriate here, as it helps users understand the valid options.

311-315: LGTM! Good defensive handling to limit synthetic data size based on config. Using `.head()` preserves consistency.

325-334: LGTM! Good optimization to reuse existing shadow models when available, avoiding redundant training.

353-357: LGTM! Good addition of loading reference population data for DOMIAS signals computation.
PR Type
[Fix | Documentation]
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868gy123e
This PR introduces several improvements to the Ensemble Attack code and fixes based on issues we found during experimentation.
A new config variable, attack_rmia_shadow_training_data_choice, is added. Several other minor fixes and improvements to the documentation are also included in this PR.
Tests Added
Existing tests are updated.